tensorflow中的动态维度

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import tensorflow as tf
sess = tf.Session()
q = tf.placeholder(tf.float32, shape=(None, 1024))

# 采用未知维度(None),初始化tensor
dim_none = tf.shape(a)[0]
p = tf.zeros([dim_none], tf.float32)

# 采用未知维度(None),初始化variable
v = tf.Variable(tf.ones([dim_none,5])) # 这里会报错

# run
rand_array = np.random.rand(10, 1024)
p_value = sess.run(p, feed_dict={q: rand_array})
p_value.shape # (10, 1024)